# -*- coding: utf-8 -*-
# Copyright (c) 2024-2024 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
import torch.nn as nn
from unittest.mock import patch, Mock
import warnings

from msmodelslim.quant.processor.anti_outlier.fused_linear import VirtualVModuleFromQKVFused


class TestVirtualVModule:
    """测试VirtualVModuleFromQKVFused的各种功能"""

    @pytest.fixture
    def qkv_module_with_bias(self):
        """创建带有偏置的QKV模块"""
        # MHA: 8个注意力头，每个头64维，总共512维
        # QKV融合: [Q(8*64), K(8*64), V(8*64)] = [512, 512, 512] -> 输出1536维
        qkv_module = nn.Linear(512, 1536, bias=True)
        return qkv_module

    @pytest.fixture
    def qkv_module_without_bias(self):
        """创建不带偏置的QKV模块"""
        qkv_module = nn.Linear(512, 1536, bias=False)
        return qkv_module

    def test_virtual_v_module_init_normal(self, qkv_module_with_bias):
        """测试初始化是否正确"""
        # 测试MHA情况
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module_with_bias,
            num_attention_heads=8,
            num_key_value_heads=8
        )
        
        # 验证初始化成功
        assert virtual_v.attention_type == "MHA"
        assert virtual_v.num_attention_heads == 8
        assert virtual_v.num_key_value_heads == 8
        assert virtual_v.qkv_module is qkv_module_with_bias
        
        # 验证V部分权重和偏置被正确提取
        assert virtual_v.weight is not None
        assert virtual_v.bias is not None
        assert virtual_v.weight.shape == (512, 512)  # V部分: 8*64=512维
        assert virtual_v.bias.shape == (512,)

    def test_virtual_v_module_update_weights(self, qkv_module_with_bias):
        """测试权重更新是否正确"""
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module_with_bias,
            num_attention_heads=8,
            num_key_value_heads=8
        )
        
        # 保存原始权重
        original_qkv_weight = virtual_v.qkv_module.weight.clone()
        original_qkv_bias = virtual_v.qkv_module.bias.clone()
        
        # 修改V部分的权重和偏置
        new_v_weight = torch.randn_like(virtual_v.weight)
        new_v_bias = torch.randn_like(virtual_v.bias)
        virtual_v.weight.data = new_v_weight
        virtual_v.bias.data = new_v_bias
        
        # 更新权重
        virtual_v.update_weights()
        
        # 验证QKV模块的V部分权重和偏置被正确更新
        head_dim = 64
        v_start = 8 * head_dim + 8 * head_dim  # Q + K
        v_end = v_start + 8 * head_dim  # + V
        
        assert torch.allclose(virtual_v.qkv_module.weight[v_start:v_end], new_v_weight)
        assert torch.allclose(virtual_v.qkv_module.bias[v_start:v_end], new_v_bias)
        
        # 验证Q和K部分没有被修改
        assert torch.allclose(virtual_v.qkv_module.weight[:v_start], original_qkv_weight[:v_start])
        assert torch.allclose(virtual_v.qkv_module.bias[:v_start], original_qkv_bias[:v_start])

    def test_virtual_v_module_determine_attention_type_mha(self):
        """测试注意力类型识别为MHA"""
        qkv_module = nn.Linear(512, 1536)
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module,
            num_attention_heads=8,
            num_key_value_heads=8
        )
        
        assert virtual_v.attention_type == "MHA"

    def test_virtual_v_module_determine_attention_type_mqa(self):
        """测试注意力类型识别为MQA"""
        qkv_module = nn.Linear(512, 1024)  # MQA: 8*64 + 1*64 + 1*64 = 640
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module,
            num_attention_heads=8,
            num_key_value_heads=1
        )
        
        assert virtual_v.attention_type == "MQA"

    def test_virtual_v_module_determine_attention_type_gqa(self):
        """测试注意力类型识别为GQA"""
        qkv_module = nn.Linear(512, 1280)  # GQA: 8*64 + 4*64 + 4*64 = 1024
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module,
            num_attention_heads=8,
            num_key_value_heads=4
        )
        
        assert virtual_v.attention_type == "GQA"

    def test_virtual_v_module_get_v_indices_mha(self):
        """测试V部分索引计算（MHA）"""
        qkv_module = nn.Linear(512, 1536)
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module,
            num_attention_heads=8,
            num_key_value_heads=8
        )
        
        head_dim = 64
        v_start, v_end = virtual_v._get_v_indices(head_dim)
        
        # MHA: Q(8*64) + K(8*64) + V(8*64)
        expected_v_start = 8 * head_dim + 8 * head_dim  # 1024
        expected_v_end = expected_v_start + 8 * head_dim  # 1536
        
        assert v_start == expected_v_start
        assert v_end == expected_v_end

    def test_virtual_v_module_get_v_indices_mqa(self):
        """测试V部分索引计算（MQA）"""
        qkv_module = nn.Linear(512, 1024)
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module,
            num_attention_heads=8,
            num_key_value_heads=1
        )
        
        head_dim = 64
        v_start, v_end = virtual_v._get_v_indices(head_dim)
        
        # MQA: Q(8*64) + K(1*64) + V(1*64)
        expected_v_start = 8 * head_dim + 1 * head_dim  # 576
        expected_v_end = expected_v_start + 1 * head_dim  # 640
        
        assert v_start == expected_v_start
        assert v_end == expected_v_end

    def test_virtual_v_module_get_v_indices_gqa(self):
        """测试V部分索引计算（GQA）"""
        qkv_module = nn.Linear(512, 1280)
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module,
            num_attention_heads=8,
            num_key_value_heads=4
        )
        
        head_dim = 64
        v_start, v_end = virtual_v._get_v_indices(head_dim)
        expected_v_start = 8 * head_dim + 4 * head_dim  # 768
        expected_v_end = expected_v_start + 4 * head_dim  # 1024
        
        assert v_start == expected_v_start
        assert v_end == expected_v_end

    def test_virtual_v_module_extract_v_weights_with_bias(self, qkv_module_with_bias):
        """测试提取V部分权重和偏置（有偏置）"""
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module_with_bias,
            num_attention_heads=8,
            num_key_value_heads=8
        )
        
        # 验证V部分权重和偏置被正确提取
        assert virtual_v.weight is not None
        assert virtual_v.bias is not None
        
        # 验证权重形状
        assert virtual_v.weight.shape == (512, 512)  # V部分: 8*64=512维
        assert virtual_v.bias.shape == (512,)
        
        # 验证提取的权重和偏置与原始QKV模块的V部分一致
        head_dim = 64
        v_start = 8 * head_dim + 8 * head_dim  # Q + K
        v_end = v_start + 8 * head_dim  # + V
        
        assert torch.allclose(virtual_v.weight, qkv_module_with_bias.weight[v_start:v_end])
        assert torch.allclose(virtual_v.bias, qkv_module_with_bias.bias[v_start:v_end])

    def test_virtual_v_module_extract_v_weights_without_bias(self, qkv_module_without_bias):
        """测试提取V部分权重和偏置（无偏置）"""
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module_without_bias,
            num_attention_heads=8,
            num_key_value_heads=8
        )
        
        # 验证V部分权重被正确提取，偏置为None
        assert virtual_v.weight is not None
        assert virtual_v.bias is None
        
        # 验证权重形状
        assert virtual_v.weight.shape == (512, 512)  # V部分: 8*64=512维
        
        # 验证提取的权重与原始QKV模块的V部分一致
        head_dim = 64
        v_start = 8 * head_dim + 8 * head_dim  # Q + K
        v_end = v_start + 8 * head_dim  # + V
        
        assert torch.allclose(virtual_v.weight, qkv_module_without_bias.weight[v_start:v_end])

    def test_virtual_v_module_update_weights_without_bias(self, qkv_module_without_bias):
        """测试更新权重（无原始偏置）"""
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module_without_bias,
            num_attention_heads=8,
            num_key_value_heads=8
        )
        
        # 创建新的偏置
        new_v_bias = torch.randn(512)
        virtual_v.bias = nn.Parameter(new_v_bias)
        
        # 修改V部分的权重
        new_v_weight = torch.randn_like(virtual_v.weight)
        virtual_v.weight.data = new_v_weight
        
        # 更新权重
        virtual_v.update_weights()
        
        # 验证QKV模块现在有了偏置
        assert virtual_v.qkv_module.bias is not None
        
        # 验证V部分的权重和偏置被正确更新
        head_dim = 64
        v_start = 8 * head_dim + 8 * head_dim  # Q + K
        v_end = v_start + 8 * head_dim  # + V
        
        assert torch.allclose(virtual_v.qkv_module.weight[v_start:v_end], new_v_weight)
        assert torch.allclose(virtual_v.qkv_module.bias[v_start:v_end], new_v_bias)
        
        # 验证Q和K部分的偏置为0
        assert torch.allclose(virtual_v.qkv_module.bias[:v_start], torch.zeros_like(virtual_v.qkv_module.bias[:v_start]))

    def test_virtual_v_module_update_weights_with_bias(self, qkv_module_with_bias):
        """测试更新权重（有原始偏置）"""
        virtual_v = VirtualVModuleFromQKVFused(
            qkv_module=qkv_module_with_bias,
            num_attention_heads=8,
            num_key_value_heads=8
        )
        
        # 保存原始偏置
        original_bias = virtual_v.qkv_module.bias.clone()
        
        # 修改V部分的权重和偏置
        new_v_weight = torch.randn_like(virtual_v.weight)
        new_v_bias = torch.randn_like(virtual_v.bias)
        virtual_v.weight.data = new_v_weight
        virtual_v.bias.data = new_v_bias
        
        # 更新权重
        virtual_v.update_weights()
        
        # 验证V部分的权重和偏置被正确更新
        head_dim = 64
        v_start = 8 * head_dim + 8 * head_dim  # Q + K
        v_end = v_start + 8 * head_dim  # + V
        
        assert torch.allclose(virtual_v.qkv_module.weight[v_start:v_end], new_v_weight)
        assert torch.allclose(virtual_v.qkv_module.bias[v_start:v_end], new_v_bias)
        
        # 验证Q和K部分的偏置没有被修改
        assert torch.allclose(virtual_v.qkv_module.bias[:v_start], original_bias[:v_start])


if __name__ == "__main__":
    pytest.main([__file__])